{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33a38d29",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-14T14:52:10.373949Z",
     "start_time": "2023-11-14T14:52:09.182936Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd46809a-ae0b-4a49-89ee-7b346b3faabe",
   "metadata": {},
   "source": [
    "## basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32da4ec6-bb98-4f5b-8390-3b6311853aa1",
   "metadata": {},
   "source": [
    "- `for name, child in module.named_children():`\n",
    "    - module 本质上维护了一个树结构；\n",
    "    - 参考 [SyncBatchNorm](https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#SyncBatchNorm)\n",
    "        - `def convert_sync_batchnorm`：这个函数递归地处理所有的 module\n",
    "- PyTorch Module — quick reference：https://medium.com/howsofcoding/pytorch-module-quick-reference-b35d5a0f9a00"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb2f75ea-2d67-49b9-b3ef-1bd1093af56b",
   "metadata": {},
   "source": [
    "## `named_children` vs. `named_modules`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "206bf950-cc6e-4cf0-acb7-d30237390fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModule(nn.Module):\n",
    "    \"\"\"Toy module for contrasting `named_children` with `named_modules`.\n",
    "\n",
    "    It registers two `nn.Sequential` containers (`layer1`, `layer2`) and one\n",
    "    bare `nn.Linear` (`misc`), which gives a two-level module tree.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, 3, 1),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Linear(32, 10),\n",
    "        )\n",
    "        # Registered as a child but not used by forward().\n",
    "        self.misc = nn.Linear(1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run `x` through `layer1` and then `layer2`, returning the result.\n",
    "\n",
    "        `nn.Linear(32, 10)` in `layer2` operates on the last dimension, so the\n",
    "        output of `layer1` needs a trailing dimension of size 32.\n",
    "        \"\"\"\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        # Without this return, calling the model would always yield None.\n",
    "        return x\n",
    "\n",
    "\n",
    "model = MyModule()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4c782cac-2a5b-4e87-968b-7556c6fe4557",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1 -> Sequential(\n",
      "  (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (1): ReLU(inplace=True)\n",
      ")\n",
      "layer2 -> Sequential(\n",
      "  (0): Linear(in_features=32, out_features=10, bias=True)\n",
      ")\n",
      "misc -> Linear(in_features=1, out_features=1, bias=True)\n"
     ]
    }
   ],
   "source": [
    "# Immediate children only: nested layers show up inside their parent's repr,\n",
    "# not as separate entries.\n",
    "for child_name, child in model.named_children():\n",
    "    print(f'{child_name} -> {child}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cf93557e-23fd-4cc1-9d20-735711922a1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " -> MyModule(\n",
      "  (layer1): Sequential(\n",
      "    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (1): ReLU(inplace=True)\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): Linear(in_features=32, out_features=10, bias=True)\n",
      "  )\n",
      "  (misc): Linear(in_features=1, out_features=1, bias=True)\n",
      ")\n",
      "layer1 -> Sequential(\n",
      "  (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (1): ReLU(inplace=True)\n",
      ")\n",
      "layer1.0 -> Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "layer1.1 -> ReLU(inplace=True)\n",
      "layer2 -> Sequential(\n",
      "  (0): Linear(in_features=32, out_features=10, bias=True)\n",
      ")\n",
      "layer2.0 -> Linear(in_features=32, out_features=10, bias=True)\n",
      "misc -> Linear(in_features=1, out_features=1, bias=True)\n"
     ]
    }
   ],
   "source": [
    "# Full traversal: the root module comes first under the empty name, then every\n",
    "# nested submodule under a dotted path such as 'layer1.0'.\n",
    "for qualified_name, submodule in model.named_modules():\n",
    "    print(qualified_name, submodule, sep=' -> ')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee438f1",
   "metadata": {},
   "source": [
    "## `model.parameters()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "19c5fba6-2f0d-4f0d-9185-c5dcd3143d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebinds `model` to a single linear layer (5 inputs -> 3 outputs); the\n",
    "# parameter listings that follow inspect this layer.\n",
    "model = nn.Linear(in_features=5, out_features=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "29d04b31",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-14T14:52:57.383206Z",
     "start_time": "2023-11-14T14:52:57.369346Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter containing:\n",
       " tensor([[-0.2888, -0.1333,  0.2464, -0.2843, -0.2926],\n",
       "         [-0.1179, -0.3022, -0.4258,  0.4373,  0.0226],\n",
       "         [ 0.1486, -0.1183, -0.3256, -0.3780, -0.4289]], requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([-0.2695, -0.1645, -0.4121], requires_grad=True)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `parameters()` returns an iterator; materialise it so the cell displays the values.\n",
    "[param for param in model.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a1168a17",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-14T14:53:01.334448Z",
     "start_time": "2023-11-14T14:53:01.321081Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('weight',\n",
       "  Parameter containing:\n",
       "  tensor([[-0.2888, -0.1333,  0.2464, -0.2843, -0.2926],\n",
       "          [-0.1179, -0.3022, -0.4258,  0.4373,  0.0226],\n",
       "          [ 0.1486, -0.1183, -0.3256, -0.3780, -0.4289]], requires_grad=True)),\n",
       " ('bias',\n",
       "  Parameter containing:\n",
       "  tensor([-0.2695, -0.1645, -0.4121], requires_grad=True))]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same parameters as above, each paired with its attribute name.\n",
    "[*model.named_parameters()]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
